The following code loads a dataset of mushroom properties (originally from http://archive.ics.uci.edu/ml/machine-learning-databases/mushroom/agaricus-lepiota.names) and fits gradient-boosted trees.

library(xgboost)
library(DiagrammeR)
data(agaricus.train, package='xgboost')
data(agaricus.test, package='xgboost')

set.seed(991)

dim(agaricus.train$data)
## [1] 6513  126
bst <- xgboost(data = agaricus.train$data, label = agaricus.train$label, 
               max_depth = 2, eta = 1, nthread = 2, nrounds = 2, 
               objective = "binary:logistic")
## [1]  train-logloss:0.233376 
## [2]  train-logloss:0.136658
pred <- predict(bst, agaricus.test$data)

# confusion matrices are useful!
table(Actual = agaricus.test$label, Predicted = pred > 0.5)
##       Predicted
## Actual FALSE TRUE
##      0   813   22
##      1    13  763
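The confusion matrix can be summarised as a single misclassification rate (a quick sketch, thresholding at 0.5 as above):

```r
# overall test-set misclassification rate at the 0.5 threshold
mean((pred > 0.5) != agaricus.test$label)
# from the table above: (22 + 13) / 1611, i.e. roughly 2% error
```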
  1. What do max_depth=2, eta=1 and nrounds=2 do?
  1. Use xgb.plot.tree to draw the tree (it appears in your browser; you need to export/save it from there)

Let’s look at the first test observation, to better understand how xgboost arrives at its prediction.

agaricus.test$data[1,]
##                   cap-shape=bell                cap-shape=conical 
##                                1                                0 
##                 cap-shape=convex                   cap-shape=flat 
##                                0                                0 
##                cap-shape=knobbed                 cap-shape=sunken 
##                                0                                0 
##              cap-surface=fibrous              cap-surface=grooves 
##                                0                                0 
##                cap-surface=scaly               cap-surface=smooth 
##                                1                                0 
##                  cap-color=brown                   cap-color=buff 
##                                0                                0 
##               cap-color=cinnamon                   cap-color=gray 
##                                0                                0 
##                  cap-color=green                   cap-color=pink 
##                                0                                0 
##                 cap-color=purple                    cap-color=red 
##                                0                                0 
##                  cap-color=white                 cap-color=yellow 
##                                1                                0 
##                 bruises?=bruises                      bruises?=no 
##                                1                                0 
##                      odor=almond                       odor=anise 
##                                0                                1 
##                    odor=creosote                       odor=fishy 
##                                0                                0 
##                        odor=foul                       odor=musty 
##                                0                                0 
##                        odor=none                     odor=pungent 
##                                0                                0 
##                       odor=spicy         gill-attachment=attached 
##                                0                                0 
##       gill-attachment=descending             gill-attachment=free 
##                                0                                1 
##          gill-attachment=notched               gill-spacing=close 
##                                0                                1 
##             gill-spacing=crowded             gill-spacing=distant 
##                                0                                0 
##                  gill-size=broad                 gill-size=narrow 
##                                1                                0 
##                 gill-color=black                 gill-color=brown 
##                                0                                1 
##                  gill-color=buff             gill-color=chocolate 
##                                0                                0 
##                  gill-color=gray                 gill-color=green 
##                                0                                0 
##                gill-color=orange                  gill-color=pink 
##                                0                                0 
##                gill-color=purple                   gill-color=red 
##                                0                                0 
##                 gill-color=white                gill-color=yellow 
##                                0                                0 
##            stalk-shape=enlarging             stalk-shape=tapering 
##                                1                                0 
##               stalk-root=bulbous                  stalk-root=club 
##                                0                                1 
##                   stalk-root=cup                 stalk-root=equal 
##                                0                                0 
##           stalk-root=rhizomorphs                stalk-root=rooted 
##                                0                                0 
##               stalk-root=missing stalk-surface-above-ring=fibrous 
##                                0                                0 
##   stalk-surface-above-ring=scaly   stalk-surface-above-ring=silky 
##                                0                                0 
##  stalk-surface-above-ring=smooth stalk-surface-below-ring=fibrous 
##                                1                                0 
##   stalk-surface-below-ring=scaly   stalk-surface-below-ring=silky 
##                                0                                0 
##  stalk-surface-below-ring=smooth     stalk-color-above-ring=brown 
##                                1                                0 
##      stalk-color-above-ring=buff  stalk-color-above-ring=cinnamon 
##                                0                                0 
##      stalk-color-above-ring=gray    stalk-color-above-ring=orange 
##                                0                                0 
##      stalk-color-above-ring=pink       stalk-color-above-ring=red 
##                                0                                0 
##     stalk-color-above-ring=white    stalk-color-above-ring=yellow 
##                                1                                0 
##     stalk-color-below-ring=brown      stalk-color-below-ring=buff 
##                                0                                0 
##  stalk-color-below-ring=cinnamon      stalk-color-below-ring=gray 
##                                0                                0 
##    stalk-color-below-ring=orange      stalk-color-below-ring=pink 
##                                0                                0 
##       stalk-color-below-ring=red     stalk-color-below-ring=white 
##                                0                                1 
##    stalk-color-below-ring=yellow                veil-type=partial 
##                                0                                1 
##              veil-type=universal                 veil-color=brown 
##                                0                                0 
##                veil-color=orange                 veil-color=white 
##                                0                                1 
##                veil-color=yellow                 ring-number=none 
##                                0                                0 
##                  ring-number=one                  ring-number=two 
##                                1                                0 
##               ring-type=cobwebby             ring-type=evanescent 
##                                0                                0 
##                ring-type=flaring                  ring-type=large 
##                                0                                0 
##                   ring-type=none                ring-type=pendant 
##                                0                                1 
##              ring-type=sheathing                   ring-type=zone 
##                                0                                0 
##          spore-print-color=black          spore-print-color=brown 
##                                0                                1 
##           spore-print-color=buff      spore-print-color=chocolate 
##                                0                                0 
##          spore-print-color=green         spore-print-color=orange 
##                                0                                0 
##         spore-print-color=purple          spore-print-color=white 
##                                0                                0 
##         spore-print-color=yellow              population=abundant 
##                                0                                0 
##             population=clustered              population=numerous 
##                                0                                0 
##             population=scattered               population=several 
##                                1                                0 
##              population=solitary                  habitat=grasses 
##                                0                                0 
##                   habitat=leaves                  habitat=meadows 
##                                0                                1 
##                    habitat=paths                    habitat=urban 
##                                0                                0 
##                    habitat=waste                    habitat=woods 
##                                0                                0

Working through the xgboost prediction for this first observation by hand…

knitr::include_graphics("data/boosting_working.png")

This hand calculation gives the same answer as the model’s own prediction:

pred[1]
## [1] 0.2858302
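The hand calculation works on the log-odds scale: each tree contributes its leaf score, the scores are summed, and a sigmoid maps the sum to a probability. One way to check this in code (assuming the `outputmargin` argument of `predict`, which returns the raw untransformed sum):

```r
# raw margin (sum of leaf scores, on the log-odds scale) for the first test row
margin <- predict(bst, agaricus.test$data[1, , drop = FALSE], outputmargin = TRUE)
plogis(margin)  # the sigmoid of the margin should reproduce pred[1]
```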

Plotting the trees

xgb.plot.tree(model = bst)
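If the browser-based plot is awkward to export, the same tree structure can also be inspected as a plain table, using `xgb.model.dt.tree` from the xgboost package:

```r
# one row per node: tree index, split feature, split value, and leaf scores
xgb.model.dt.tree(model = bst)
```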
  1. Fit a model with the same options for max_depth=2 and eta=1 but with nrounds chosen to minimise cross-validation loss. Use xgb.plot.tree to plot it. Comment on the relative accuracy and complexity of the two models
xgb_best <- xgb.cv(data = agaricus.train$data, label = agaricus.train$label, 
               max_depth = 2, eta = 1, nthread = 2, nrounds = 30, nfold = 5,
               objective = "binary:logistic", metrics = "error")
## [1]  train-error:0.046522+0.000911   test-error:0.046522+0.003644 
## [2]  train-error:0.022263+0.000676   test-error:0.022263+0.002704 
## [3]  train-error:0.007063+0.000255   test-error:0.007063+0.001019 
## [4]  train-error:0.015200+0.000477   test-error:0.015201+0.001909 
## [5]  train-error:0.007063+0.000255   test-error:0.007063+0.001019 
## [6]  train-error:0.001689+0.000989   test-error:0.002303+0.002003 
## [7]  train-error:0.001228+0.000153   test-error:0.001228+0.000614 
## [8]  train-error:0.001228+0.000153   test-error:0.001228+0.000614 
## [9]  train-error:0.001152+0.000172   test-error:0.001228+0.000614 
## [10] train-error:0.001152+0.000172   test-error:0.001228+0.000614 
## [11] train-error:0.000960+0.000500   test-error:0.001075+0.000783 
## [12] train-error:0.000422+0.000521   test-error:0.000767+0.000971 
## [13] train-error:0.000000+0.000000   test-error:0.000000+0.000000 
## [14] train-error:0.000000+0.000000   test-error:0.000000+0.000000 
## [15] train-error:0.000000+0.000000   test-error:0.000000+0.000000 
## [16] train-error:0.000000+0.000000   test-error:0.000000+0.000000 
## [17] train-error:0.000000+0.000000   test-error:0.000000+0.000000 
## [18] train-error:0.000000+0.000000   test-error:0.000000+0.000000 
## [19] train-error:0.000000+0.000000   test-error:0.000000+0.000000 
## [20] train-error:0.000000+0.000000   test-error:0.000000+0.000000 
## [21] train-error:0.000000+0.000000   test-error:0.000000+0.000000 
## [22] train-error:0.000000+0.000000   test-error:0.000000+0.000000 
## [23] train-error:0.000000+0.000000   test-error:0.000000+0.000000 
## [24] train-error:0.000000+0.000000   test-error:0.000000+0.000000 
## [25] train-error:0.000000+0.000000   test-error:0.000000+0.000000 
## [26] train-error:0.000000+0.000000   test-error:0.000000+0.000000 
## [27] train-error:0.000000+0.000000   test-error:0.000000+0.000000 
## [28] train-error:0.000000+0.000000   test-error:0.000000+0.000000 
## [29] train-error:0.000000+0.000000   test-error:0.000000+0.000000 
## [30] train-error:0.000000+0.000000   test-error:0.000000+0.000000

Note: the training loss is log-loss (“logloss”); here we also request the metric “error”, the misclassification rate, which is easier to interpret for binary classification.

The cross-validation test error first reaches zero at round 13, but it is already very small (about 0.7%) by round 3, so a 3-round model is a reasonable, much simpler choice.
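Rather than reading the best round off the printed log, it can be extracted from the evaluation log that `xgb.cv` returns (column names such as `test_error_mean` are as in recent xgboost versions):

```r
# evaluation_log is a data.table with one row per boosting round
best_round <- which.min(xgb_best$evaluation_log$test_error_mean)
best_round
xgb_best$evaluation_log[best_round]
```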

xgb.cv(data = agaricus.train$data, label = agaricus.train$label, 
               max_depth = 2, eta = 1, nthread = 2, nrounds = 3, nfold = 5,
               objective = "binary:logistic", metrics = "error")
## [1]  train-error:0.046522+0.001388   test-error:0.046523+0.005555 
## [2]  train-error:0.022263+0.001196   test-error:0.022264+0.004784 
## [3]  train-error:0.007063+0.000561   test-error:0.007062+0.002245
bst2 <- xgboost(data = agaricus.train$data, label = agaricus.train$label, 
               max_depth = 2, eta = 1, nthread = 2, nrounds = 3, 
               objective = "binary:logistic")
## [1]  train-logloss:0.233376 
## [2]  train-logloss:0.136658 
## [3]  train-logloss:0.082531
xgb.plot.tree(model = bst2)
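To comment on relative accuracy, the two fitted models can be scored on the held-out test set (a sketch, thresholding at 0.5 as before):

```r
# test-set misclassification rate for each model
pred2 <- predict(bst2, agaricus.test$data)
mean((pred > 0.5) != agaricus.test$label)   # original 2-round model
mean((pred2 > 0.5) != agaricus.test$label)  # 3-round model
```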
  1. Now try lowering the learning rate eta to reduce cross-validation loss. (think about a strategy for choosing values of eta to try, but don’t try more than five or so different ones)
for (i in c(0.1, 0.25, 0.5, 0.75, 1.0)) {
  print(i)
  xgb.cv(data = agaricus.train$data, label = agaricus.train$label, 
                 max_depth = 2, eta = i, nthread = 2, nrounds = 10, nfold = 5,
                 objective = "binary:logistic", metrics = "error")
}
## [1] 0.1
## [1]  train-error:0.046522+0.001003   test-error:0.046521+0.004010 
## [2]  train-error:0.042569+0.001180   test-error:0.042683+0.005848 
## [3]  train-error:0.046522+0.001003   test-error:0.046521+0.004010 
## [4]  train-error:0.041609+0.000971   test-error:0.041608+0.003886 
## [5]  train-error:0.041609+0.000971   test-error:0.041608+0.003886 
## [6]  train-error:0.041609+0.000971   test-error:0.041608+0.003886 
## [7]  train-error:0.037694+0.007680   test-error:0.038997+0.007148 
## [8]  train-error:0.030554+0.008559   test-error:0.031014+0.011339 
## [9]  train-error:0.041609+0.000971   test-error:0.041608+0.003886 
## [10] train-error:0.023338+0.000782   test-error:0.023338+0.003130 
## [1] 0.25
## [1]  train-error:0.046522+0.001022   test-error:0.046523+0.004091 
## [2]  train-error:0.046522+0.001022   test-error:0.046523+0.004091 
## [3]  train-error:0.023338+0.000415   test-error:0.023338+0.001661 
## [4]  train-error:0.041609+0.000793   test-error:0.041610+0.003172 
## [5]  train-error:0.009443+0.006902   test-error:0.009519+0.007306 
## [6]  train-error:0.015200+0.004545   test-error:0.015354+0.003435 
## [7]  train-error:0.013473+0.002327   test-error:0.013818+0.001813 
## [8]  train-error:0.018693+0.002386   test-error:0.019653+0.003904 
## [9]  train-error:0.019883+0.001061   test-error:0.020728+0.003730 
## [10] train-error:0.020881+0.001198   test-error:0.020881+0.003481 
## [1] 0.5
## [1]  train-error:0.046522+0.001308   test-error:0.046521+0.005233 
## [2]  train-error:0.045179+0.001532   test-error:0.046061+0.005590 
## [3]  train-error:0.023530+0.001287   test-error:0.024874+0.003315 
## [4]  train-error:0.028904+0.003929   test-error:0.029939+0.007388 
## [5]  train-error:0.013665+0.003770   test-error:0.013359+0.003783 
## [6]  train-error:0.016582+0.001345   test-error:0.014892+0.005835 
## [7]  train-error:0.003032+0.002236   test-error:0.002918+0.001489 
## [8]  train-error:0.008560+0.003179   test-error:0.007522+0.004214 
## [9]  train-error:0.001996+0.000260   test-error:0.001996+0.001041 
## [10] train-error:0.001689+0.000444   test-error:0.001689+0.001128 
## [1] 0.75
## [1]  train-error:0.046522+0.000978   test-error:0.046523+0.003913 
## [2]  train-error:0.039230+0.010643   test-error:0.039611+0.010708 
## [3]  train-error:0.026447+0.005216   test-error:0.026869+0.003434 
## [4]  train-error:0.016122+0.007104   test-error:0.016735+0.006854 
## [5]  train-error:0.010095+0.002944   test-error:0.009674+0.003353 
## [6]  train-error:0.002149+0.000818   test-error:0.001689+0.000895 
## [7]  train-error:0.002495+0.000993   test-error:0.002149+0.001128 
## [8]  train-error:0.001689+0.000352   test-error:0.001689+0.001128 
## [9]  train-error:0.001651+0.000464   test-error:0.001842+0.000921 
## [10] train-error:0.001420+0.000773   test-error:0.001535+0.001284 
## [1] 1
## [1]  train-error:0.051014+0.009867   test-error:0.054350+0.013054 
## [2]  train-error:0.021188+0.002272   test-error:0.021649+0.005031 
## [3]  train-error:0.009788+0.005480   test-error:0.010593+0.007106 
## [4]  train-error:0.014087+0.002638   test-error:0.014585+0.002950 
## [5]  train-error:0.005950+0.002231   test-error:0.006449+0.001855 
## [6]  train-error:0.001305+0.000144   test-error:0.001689+0.001128 
## [7]  train-error:0.001228+0.000094   test-error:0.001228+0.000376 
## [8]  train-error:0.001305+0.000144   test-error:0.001689+0.001128 
## [9]  train-error:0.001228+0.000094   test-error:0.001228+0.000376 
## [10] train-error:0.000461+0.000564   test-error:0.000614+0.000753
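The loop above prints five separate logs that are hard to compare side by side. A sketch that instead records the minimum cross-validated error, and the round at which it occurs, for each learning rate (`verbose = 0` suppresses the per-round printing; `evaluation_log` and `test_error_mean` are as in recent xgboost versions):

```r
etas <- c(0.1, 0.25, 0.5, 0.75, 1.0)
results <- data.frame(eta = etas, best_error = NA_real_, best_round = NA_integer_)
for (j in seq_along(etas)) {
  cv <- xgb.cv(data = agaricus.train$data, label = agaricus.train$label,
               max_depth = 2, eta = etas[j], nthread = 2, nrounds = 10,
               nfold = 5, objective = "binary:logistic",
               metrics = "error", verbose = 0)
  results$best_error[j] <- min(cv$evaluation_log$test_error_mean)
  results$best_round[j] <- which.min(cv$evaluation_log$test_error_mean)
}
results
```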

We look at the smallest test errors (the dips) for each learning rate to judge how many boosting rounds are needed. With eta = 1 the test error is already around 1% by round 3 and reaches its minimum at round 10; unsurprisingly, the number of rounds needed grows as the learning rate shrinks, and with eta = 0.1 the test error is still above 2% after all ten rounds.

  1. Data wrangling: the file mushroom.test contains descriptions of three new mushrooms. How does the first model classify their edibility? To convert the new data into the correct matrix form, you will need to construct column names as they are in the main data set. The names from the main data can be retrieved using dimnames(agaricus.train$data)[[2]].
# the conversion and prediction are worked through by hand in this diagram
knitr::include_graphics("data/boosting_new_working.jpg")
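The diagram shows the idea; a code sketch of the conversion follows. The feature strings below are made-up placeholders, since the contents of mushroom.test are not shown here:

```r
# training column names, e.g. "cap-shape=bell", "odor=none", ...
cols <- dimnames(agaricus.train$data)[[2]]

# hypothetical mushrooms described as "variable=level" strings matching cols
new_obs <- list(
  A = c("cap-shape=convex", "odor=none", "gill-size=broad"),
  B = c("cap-shape=flat",   "odor=foul", "gill-size=narrow")
)

# one-hot encode: 1 where a mushroom has that feature, 0 otherwise
newX <- t(vapply(new_obs, function(feats) as.numeric(cols %in% feats),
                 numeric(length(cols))))
colnames(newX) <- cols
predict(bst, newX)
```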

  1. The mushrooms are A: Amanita phalloides, B: Amanita virosa, C: Volvariella volvacea. Look up their common names. Comment on the usefulness of the model.

After some research: A (the death cap) and B (the destroying angel) are both deadly poisonous, yet the model predicts both as edible, which is dangerously wrong. C (the paddy straw mushroom) is edible and is correctly predicted as edible. The model is not useful for judging unfamiliar mushrooms: the training data covers only species of the Agaricus and Lepiota families (from a North American field guide), so its rules cannot be expected to transfer to other genera such as Amanita or Volvariella.